{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import warnings\n",
    "import numpy as np\n",
    "import os\n",
    "import cv2\n",
    "import tqdm\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "device = 'cuda'\n",
    "shape = (H,W,C) = (256,256,3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import models\n",
    "raft = models.optical_flow.raft_small(weights = 'Raft_Small_Weights.C_T_V2').eval().to(device)\n",
    "\n",
    "def get_flow(img1, img2):\n",
    "    img1_t = torch.from_numpy(img1/255.0).permute(-1,0,1).unsqueeze(0).float().to(device)\n",
    "    img2_t = torch.from_numpy(img2/255.0).permute(-1,0,1).unsqueeze(0).float().to(device)\n",
    "    flow = raft(img1_t,img2_t)[-1].detach().cpu().permute(0,2,3,1).squeeze(0).numpy()\n",
    "    return flow\n",
    "\n",
    "def show_flow(flow):\n",
    "    hsv_mask = np.zeros(shape= flow.shape[:-1] +(3,),dtype = np.uint8)\n",
    "    hsv_mask[...,1] = 255\n",
    "    mag, ang = cv2.cartToPolar(flow[...,0], flow[...,1],angleInDegrees=True)\n",
    "    hsv_mask[:,:,0] = ang /2 \n",
    "    hsv_mask[:,:,2] = cv2.normalize(mag,None,0,255,cv2.NORM_MINMAX)\n",
    "    rgb = cv2.cvtColor(hsv_mask,cv2.COLOR_HSV2RGB)\n",
    "    return(rgb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|▊         | 2/23 [00:25<04:22, 12.52s/it]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[8], line 11\u001b[0m\n\u001b[0;32m      9\u001b[0m flows \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m     10\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[1;32m---> 11\u001b[0m     ret,curr \u001b[38;5;241m=\u001b[39m \u001b[43mcap\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     12\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m ret: \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[0;32m     13\u001b[0m     curr \u001b[38;5;241m=\u001b[39m cv2\u001b[38;5;241m.\u001b[39mresize(curr,(W,H))\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "videos_path = 'E:/Datasets/DeepStab_Dataset/stable/'\n",
    "flows_path = 'E:/Datasets/Flows/'\n",
    "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n",
    "videos = os.listdir(videos_path)\n",
    "for video in tqdm.tqdm(videos):\n",
    "    cap = cv2.VideoCapture(os.path.join(videos_path, video))\n",
    "    ret,prev = cap.read()\n",
    "    prev = cv2.resize(prev,(W,H))\n",
    "    flows = []\n",
    "    while True:\n",
    "        ret,curr = cap.read()\n",
    "        if not ret: break\n",
    "        curr = cv2.resize(curr,(W,H))\n",
    "        flow = get_flow(prev,curr)\n",
    "        flows.append(flow)\n",
    "        prev = curr\n",
    "        cv2.imshow('window',show_flow(flow))\n",
    "        if cv2.waitKey(1) & 0xFF == ord('9'):\n",
    "            break\n",
    "    flows = np.array(flows).astype(np.float32)\n",
    "    output_path = os.path.join(flows_path,video.split('.')[0] + '.npy')\n",
    "    np.save(output_path,flows)\n",
    "cv2.destroyAllWindows()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "DUTCode",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
